import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_openml, fetch_covtype

# Registry of datasets supported in experiments. Each key is a name accepted by
# load_dataset_safely; each value is a loader spec in one of two shapes:
#   {"fetch": {...}}       -> the inner dict is forwarded as keyword arguments
#                             to sklearn's fetch_openml (name/version or data_id)
#   {"fetch_covtype": {}}  -> loaded with sklearn's fetch_covtype; only the
#                             presence of the key is checked, the value is unused
DATASETS = {
    "iris": {"fetch": {"name": "iris", "version": 1}},
    "adult": {"fetch": {"name": "adult", "version": 2}},
    "covtype": {"fetch_covtype": {}},
    "credit-g": {"fetch": {"data_id": 31}},
    # Known common alias for Bank Marketing.
    # NOTE(review): no "version" is pinned here, unlike iris/adult -- confirm
    # which OpenML version this name resolves to.
    "bank-marketing": {"fetch": {"name": "bank-marketing"}},
    # Fallback alias used in some papers; addressed by numeric OpenML data_id
    # rather than by name.
    "bank-marketing-1": {"fetch": {"data_id": 1461}},
}


def _clean_features_inplace(X: pd.DataFrame) -> None:
    """Normalize text cells of ``X`` in place and map '?' sentinels to NaN."""
    for col in X.select_dtypes(include=["object"]).columns:
        column = X[col]
        # A plain ``astype(str)`` would turn genuine missing values into the
        # literal string "nan" (a bogus category). Keep missing cells missing
        # and only stringify/strip the non-null ones.
        X[col] = column.where(column.isna(), column.astype(str).str.strip())
    # Replace after stripping so padded sentinels such as " ?" are caught too.
    X.replace("?", np.nan, inplace=True)


def load_dataset_safely(name: str):
    """Load a registered dataset and split it into train/val/test parts.

    The dataset is looked up in ``DATASETS`` and fetched with
    ``fetch_covtype`` or ``fetch_openml`` (both with ``as_frame=True``).
    DataFrame features are cleaned (whitespace stripped from object columns,
    '?' placeholders replaced by NaN) and object/categorical targets are
    label-encoded to integers.

    The data is split 80/20 into train+val/test, then the first part is split
    80/20 again into train/val (64% / 16% / 20% overall), both with
    ``random_state=42``.

    Args:
        name: Key of the dataset in ``DATASETS``.

    Returns:
        ``(dataset_dict, info_msg)`` on success, where ``dataset_dict`` has the
        keys ``X_train``, ``X_val``, ``X_test``, ``y_train``, ``y_val``,
        ``y_test``, ``feature_names`` and ``n_classes``. On any failure
        (unknown name, fetch error, ...) returns ``(None, error_msg)``; no
        exception is propagated to the caller.
    """
    try:
        spec = DATASETS.get(name)
        if spec is None:
            return None, f"Unknown dataset: {name}"

        if "fetch_covtype" in spec:
            data = fetch_covtype(as_frame=True)
        else:
            data = fetch_openml(as_frame=True, **spec["fetch"])
        # Copy so the cleaning below never touches sklearn's cached objects.
        X = data.data.copy()
        y = data.target.copy()

        # Clean common missing-value sentinels (e.g. '?' in Adult).
        if isinstance(X, pd.DataFrame):
            _clean_features_inplace(X)

        # Encode y if categorical/object
        if y.dtype == "object" or getattr(y.dtype, "name", "") == "category":
            from sklearn.preprocessing import LabelEncoder
            y = LabelEncoder().fit_transform(y)

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )
        # Carve the validation set out of the remaining training portion.
        X_train, X_val, y_train, y_val = train_test_split(
            X_train, y_train, test_size=0.2, random_state=42
        )

        ds = {
            "X_train": X_train,
            "X_test": X_test,
            "X_val": X_val,
            "y_train": y_train,
            "y_test": y_test,
            "y_val": y_val,
            "feature_names": getattr(X, "columns", None),
            "n_classes": int(len(np.unique(y))),
        }
        return ds, f"Loaded dataset '{name}' with shape {X.shape}"
    except Exception as e:
        # Deliberate catch-all: the contract is to report failures as a
        # message instead of raising.
        return None, f"Failed to load dataset '{name}': {e}"


def seed_everything(seed: int = 42):
    """Seed Python, NumPy and (when installed) PyTorch random generators.

    Python's ``random`` module and NumPy's global RNG are always seeded.
    PyTorch seeding is best-effort: if PyTorch is not installed it is skipped,
    and if a PyTorch call fails a ``RuntimeWarning`` is emitted instead of
    raising.

    Args:
        seed: Seed value applied to every generator.
    """
    import os
    import random
    import warnings

    import numpy as _np

    # Seed the always-available generators first so a missing or broken
    # PyTorch installation cannot prevent them from being seeded.
    random.seed(seed)
    _np.random.seed(seed)
    # Hash randomization is fixed at interpreter start-up, so this only takes
    # effect in processes launched after this call (e.g. worker subprocesses).
    os.environ["PYTHONHASHSEED"] = str(seed)

    try:
        import torch
    except ImportError:
        # PyTorch is optional; nothing more to seed.
        return

    try:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning speed for deterministic kernels.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except Exception as exc:
        warnings.warn(
            f"Could not fully seed PyTorch: {exc}",
            RuntimeWarning,
            stacklevel=2,
        )


def safe_import(package_name):
    """Import a module by name without raising.

    Args:
        package_name: Absolute module name; dotted names such as
            ``"json.decoder"`` are supported.

    Returns:
        ``(module, None)`` on success, or ``(None, error_message)`` if the
        import raised any exception.
    """
    import importlib

    try:
        # ``__import__("a.b")`` returns the top-level package ``a``;
        # ``import_module`` returns the module that was actually requested.
        module = importlib.import_module(package_name)
    except Exception as e:
        # Importing runs arbitrary module code, so any exception type can
        # surface here, not just ImportError.
        return None, str(e)
    return module, None
